import torch
import torch.nn as nn


class kl_loss(nn.Module):
    """Kullback-Leibler divergence loss between saliency-map distributions.

    Both the ground-truth and predicted maps are normalized to sum to 1 per
    image before the divergence is computed, so inputs need not be valid
    probability distributions on entry.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def forward(y_true, y_pred, eps=1e-7):
        """Compute the KL divergence between ground truth and predictions.

        Values are first divided by their per-image sum so each map forms a
        distribution that adds to 1.

        Args:
            y_true (tensor, float32): A 4D tensor that holds the ground truth
                                      saliency maps with values between 0 and 255.
            y_pred (tensor, float32): A 4D tensor that holds the predicted saliency
                                      maps with values between 0 and 1.
            eps (scalar, float, optional): A small factor to avoid numerical
                                           instabilities. Defaults to 1e-7.

        Returns:
            tensor, float32: A 0D tensor that holds the error averaged over the batch.
        """
        # Normalize each ground-truth map to a per-image distribution;
        # eps guards against division by zero for all-zero maps.
        sum_per_image = torch.sum(y_true, dim=(1, 2, 3), keepdim=True)
        y_true_normal = y_true / (eps + sum_per_image)

        sum_per_image = torch.sum(y_pred, dim=(1, 2, 3), keepdim=True)
        y_pred_normal = y_pred / (eps + sum_per_image)

        # Pointwise KL term p * log(p / q); the inner eps keeps log() finite
        # where the predicted distribution is (near-)zero.
        loss = y_true_normal * torch.log(eps + y_true_normal / (eps + y_pred_normal))

        # Sum over the spatial/channel dims per image, then average over the batch.
        return torch.mean(torch.sum(loss, dim=(1, 2, 3)))
